import json

import torch
import torch.distributed as dist

from vlmeval.config import supported_VLM
from vlmeval.dataset.video_dataset_config import supported_video_datasets
from vlmeval.dataset import build_dataset
from vlmeval.inference import infer_data_job
from vlmeval.inference_video import infer_data_job_video
from vlmeval.inference_mt import infer_data_job_mt
from vlmeval.smp import *
from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer


def build_model_from_config(cfg, model_name):
    """Instantiate the model ``model_name`` from the ``model`` section of a --config file.

    Args:
        cfg: Mapping from model name to that model's config dict (the ``"model"``
            field of the config JSON).
        model_name: Key into ``cfg``.

    Returns:
        A model instance. An empty config dict selects the predefined shortcut
        ``supported_VLM[model_name]``. Otherwise ``config["class"]`` names a class
        looked up first in ``vlmeval.api``, then in ``vlmeval.vlm``. That class is
        called with the remaining config keys as keyword arguments.

    Raises:
        ValueError: if an empty config names a model missing from ``supported_VLM``,
            if a non-empty config lacks ``"class"``, or if the class exists in
            neither module.
    """
    import vlmeval.api
    import vlmeval.vlm

    # Deep-copy so that popping "class" below does not mutate the caller's config.
    config = cp.deepcopy(cfg[model_name])
    if config == {}:
        if model_name not in supported_VLM:
            raise ValueError(
                f"Model {model_name} has an empty config but is not defined in "
                "`supported_VLM`; please specify its `class` and arguments."
            )
        return supported_VLM[model_name]()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "class" not in config:
        raise ValueError(f"The config of model {model_name} must specify `class`")
    cls_name = config.pop("class")
    for module in (vlmeval.api, vlmeval.vlm):
        if hasattr(module, cls_name):
            return getattr(module, cls_name)(**config)
    raise ValueError(
        f"Class {cls_name} is not supported in `vlmeval.api` or `vlmeval.vlm`"
    )


def build_dataset_from_config(cfg, dataset_name):
    """Instantiate the dataset ``dataset_name`` from the ``data`` section of a --config file.

    Args:
        cfg: Mapping from dataset name to that dataset's config dict (the ``"data"``
            field of the config JSON).
        dataset_name: Key into ``cfg``.

    Returns:
        A dataset instance. An empty config dict selects the predefined shortcut
        ``supported_video_datasets[dataset_name]``. Otherwise ``config["class"]``
        names a class in ``vlmeval.dataset``. That class is called with those
        remaining config keys that appear in its ``__init__`` signature.

    Raises:
        ValueError: if an empty config names an unknown shortcut, if ``"class"`` is
            missing or not found in ``vlmeval.dataset``, or if a video dataset sets
            both or neither of ``fps`` / ``nframe`` to a positive value.
    """
    import vlmeval.dataset
    import inspect

    # Deep-copy so that popping "class" below does not mutate the caller's config.
    config = cp.deepcopy(cfg[dataset_name])
    if config == {}:
        if dataset_name not in supported_video_datasets:
            raise ValueError(
                f"Dataset {dataset_name} has an empty config but is not defined in "
                "`supported_video_datasets`; please specify `class` and `dataset`."
            )
        return supported_video_datasets[dataset_name]()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if "class" not in config:
        raise ValueError(f"The config of dataset {dataset_name} must specify `class`")
    cls_name = config.pop("class")
    if not hasattr(vlmeval.dataset, cls_name):
        raise ValueError(f"Class {cls_name} is not supported in `vlmeval.dataset`")

    cls = getattr(vlmeval.dataset, cls_name)
    sig = inspect.signature(cls.__init__)
    # NOTE: keys that are not named parameters of __init__ are silently dropped.
    valid_params = {k: v for k, v in config.items() if k in sig.parameters}
    if getattr(cls, "MODALITY", None) == "VIDEO":
        # Treat a missing key and an explicit JSON null the same way: "not set".
        fps = valid_params.get("fps") or 0
        nframe = valid_params.get("nframe") or 0
        if fps > 0 and nframe > 0:
            raise ValueError("fps and nframe should not be set at the same time")
        if fps <= 0 and nframe <= 0:
            raise ValueError(
                "fps and nframe should be set at least one valid value"
            )
    return cls(**valid_params)


def parse_args():
    help_msg = """\
You can launch the evaluation by setting either --data and --model or --config.

--data and --model:
    Each Arg should be a list of strings, specifying the names of datasets and models.
    To find all supported model names, please refer to the `vlmeval/config.py` of check the output of the command \
        `vlmutil mlist all` in the terminal (you should first have vlmeval installed).
    To find all supported dataset names, please refer to the `vlmeval/dataset/__init__.py` file. The python script \
        to print all supported dataset names is as follows:
        ```python
        from vlmeval.dataset import SUPPORTED_DATASETS
        print(SUPPORTED_DATASETS)
        ```
        or you can check the output of the command `vlmutil dlist all` in the terminal.
    To find all supported video dataset default settings, please refer to the \
        `vlmeval/dataset/video_dataset_config.py` file.

--config:
    Launch the evaluation by specifying the path to the config json file. Sample Json Content:
    ```json
    {
        "model": {
            "GPT4o_20240806_T00_HIGH": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 0,
                "img_detail": "high"
            },
            "GPT4o_20240806_T10_Low": {
                "class": "GPT4V",
                "model": "gpt-4o-2024-08-06",
                "temperature": 1.0,
                "img_detail": "low"
            },
            "GPT4o_20241120": {}
        },
        "data": {
            "MME-RealWorld-Lite": {
                "class": "MMERealWorld",
                "dataset": "MME-RealWorld-Lite"
            },
            "MMBench_DEV_EN_V11": {
                "class": "ImageMCQDataset",
                "dataset": "MMBench_DEV_EN_V11"
            },
            "MMBench_Video_8frame_nopack": {},
            "Video-MME_16frame_subs": {
                "class": "VideoMME",
                "dataset": "Video-MME",
                "nframe": 16,
                "use_subtitle": true,
            }
        }
    }
    ```
    Currently, only `model` and `data` are supported fields. The content of each field is a dictionary.
    For `model`, the key is the name of the model, and the value is a dictionary containing the following keys:
    - `class`: The class name of the model, which should be a class in `vlmeval.vlm` or `vlmeval.api`.
    - Other keys are specific to the model, please refer to the corresponding class.
    - Tip: The defined model in the `supported_VLM` of `vlmeval/config.py` can be used as a shortcut.
    For `data`, the key is the name of the dataset (should be the same as the `dataset` field in most cases, \
        except for video datasets), and the value is a dictionary containing the following keys:
    - `class`: The class name of the dataset, which should be a class in `vlmeval.dataset`.
    - `dataset`: The name of the dataset, which should be a string that is accepted by the `dataset` argument of the \
        corresponding class.
    - Other keys are specific to the dataset, please refer to the corresponding class.
    - Tip: The defined dataset in the `supported_video_datasets` of `vlmeval/dataset/video_dataset_config.py` \
        can be used as a shortcut.

    The keys in the `model` and `data` fields will be used for naming the prediction files and evaluation results.
    When launching with `--config`, args for API VLMs, such as `--retry`, `--verbose`, will be ignored.
"""
    parser = argparse.ArgumentParser(
        description=help_msg, formatter_class=argparse.RawTextHelpFormatter
    )
    # Essential Args, Setting the Names of Datasets and Models
    parser.add_argument("--data", type=str, nargs="+", help="Names of Datasets")
    parser.add_argument("--model", type=str, nargs="+", help="Names of Models")
    parser.add_argument("--config", type=str, help="Path to the Config Json File")
    # Work Dir
    parser.add_argument(
        "--work-dir", type=str, default="./outputs", help="select the output directory"
    )
    # Infer + Eval or Infer Only
    parser.add_argument("--mode", type=str, default="all", choices=["all", "infer"])
    # API Kwargs, Apply to API VLMs and Judge API LLMs
    parser.add_argument("--api-nproc", type=int, default=4, help="Parallel API calling")
    parser.add_argument(
        "--retry", type=int, default=None, help="retry numbers for API VLMs"
    )
    parser.add_argument(
        "--judge-args", type=str, default=None, help="Judge arguments in JSON format"
    )
    # Explicitly Set the Judge Model
    parser.add_argument("--judge", type=str, default=None)
    # Logging Utils
    parser.add_argument("--verbose", action="store_true")
    # Configuration for Resume
    # Ignore: will not rerun failed VLM inference
    parser.add_argument("--ignore", action="store_true", help="Ignore failed indices. ")
    # Reuse: will reuse the existing prediction files
    parser.add_argument("--reuse", action="store_true")
    # Reuse-aux: if set, when reuse is True, will also reuse the auxiliary evaluation files
    parser.add_argument(
        "--reuse-aux", type=bool, default=True, help="reuse auxiliary evaluation files"
    )

    args = parser.parse_args()
    return args


def _build_dataset_for_run(use_config, cfg, dataset_name, model_name):
    """Build one dataset from the --config ``data`` section or by its registered name.

    Returns whatever the underlying builder returns; the caller treats ``None`` as
    "invalid dataset, skip it".
    """
    if use_config:
        return build_dataset_from_config(cfg["data"], dataset_name)
    dataset_kwargs = {}
    # These datasets take the model name as a constructor argument.
    if dataset_name in [
        "MMLongBench_DOC",
        "DUDE",
        "DUDE_MINI",
        "SLIDEVQA",
        "SLIDEVQA_MINI",
    ]:
        dataset_kwargs["model"] = model_name
    return build_dataset(dataset_name, **dataset_kwargs)


def _reuse_previous_results(
    prev_pred_roots, pred_root, result_file_base, dataset_name, commit_id, reuse_aux, logger
):
    """Copy reusable prediction files from earlier run directories into ``pred_root``.

    ``prev_pred_roots`` is scanned from its last entry to its first (the caller sorts
    it first). The first directory containing ``result_file_base`` wins: that file is
    copied, plus the files returned by ``fetch_aux_files`` when ``reuse_aux`` is set.
    If no such directory exists, ``.pkl`` files matching ``dataset_name`` are taken
    from the first other directory whose path contains ``commit_id``. Files that
    already exist in ``pred_root`` are never overwritten.
    """
    prev_result_files = []
    prev_pkl_file_list = []
    for root in prev_pred_roots[::-1]:
        if osp.exists(osp.join(root, result_file_base)):
            if reuse_aux:
                prev_result_files = fetch_aux_files(osp.join(root, result_file_base))
            else:
                prev_result_files = [osp.join(root, result_file_base)]
            break
        elif commit_id in root and len(ls(root)) and root != pred_root:
            temp_files = ls(root, match=[dataset_name, ".pkl"])
            if len(temp_files):
                prev_pkl_file_list.extend(temp_files)
                break

    # Final prediction files take precedence over intermediate pickle checkpoints.
    if prev_result_files:
        files, kind = prev_result_files, "prediction file"
    else:
        files, kind = prev_pkl_file_list, "prediction pickle file"
    for src in files:
        tgt = osp.join(pred_root, osp.basename(src))
        if not osp.exists(tgt):
            shutil.copy(src, tgt)
            logger.info(f"--reuse is set, will reuse the {kind} {src}.")
        else:
            logger.warning(f"File already exists: {tgt}")


def _default_judge_model(dataset, dataset_name):
    """Return the default judge model for a dataset, or ``None`` if no rule applies."""
    if dataset.TYPE in ["MCQ", "Y/N", "MCQ_MMMU_Pro"]:
        if listinstr(["WeMath"], dataset_name):
            return "gpt-4o-mini"
        return "chatgpt-0125"
    if listinstr(["MMVet", "LLaVABench", "MMBench-Video"], dataset_name):
        return "gpt-4-turbo"
    if listinstr(
        [
            "MathVista",
            "MathVerse",
            "MathVision",
            "DynaMath",
            "VL-RewardBench",
            "LogicVista",
        ],
        dataset_name,
    ):
        return "gpt-4o-mini"
    if listinstr(
        [
            "MMLongBench",
            "MMDU",
            "DUDE",
            "SLIDEVQA",
            "MIA-Bench",
            "WildVision",
        ],
        dataset_name,
    ):
        return "gpt-4o"
    return None


def _build_judge_kwargs(args, dataset, dataset_name):
    """Assemble the keyword arguments passed to dataset evaluation / result transfer.

    Precedence:
    - ``retry``: ``--retry``, then ``retry`` in ``--judge-args``, then 3.
    - ``model``: ``--judge``, then ``model`` in ``--judge-args``, then the
      per-dataset default from ``_default_judge_model`` (left unset if none applies).
    """
    judge_kwargs = {
        "nproc": args.api_nproc,
        "verbose": args.verbose,
        "retry": 3,
        **(json.loads(args.judge_args) if args.judge_args else {}),
    }
    if args.retry is not None:
        judge_kwargs["retry"] = args.retry
    if args.judge is not None:
        judge_kwargs["model"] = args.judge
    elif "model" not in judge_kwargs:
        # Do not clobber a judge model the user supplied through --judge-args.
        default_model = _default_judge_model(dataset, dataset_name)
        if default_model is not None:
            judge_kwargs["model"] = default_model
    return judge_kwargs


def _skip_evaluation(dataset_name, result_file, logger):
    """Return True (after logging why) when ``dataset_name`` cannot be evaluated locally."""
    if "MLLMGuard_DS" in dataset_name:
        logger.info("The evaluation of MLLMGuard_DS is not supported yet. ")
        return True
    if "AesBench_TEST" == dataset_name:
        logger.info(
            f"The results are saved in {result_file}. "
            f"Please send it to the AesBench Team via huangyipo@hotmail.com."
        )
        return True
    if dataset_name in [
        "DocVQA_TEST",
        "InfoVQA_TEST",
        "Q-Bench1_TEST",
        "A-Bench_TEST",
    ]:
        logger.info(
            f"{dataset_name} is a test split without ground-truth. "
            "Thus only the inference part is supported for those datasets. "
        )
        return True
    if dataset_name in [
        "MMBench_TEST_CN",
        "MMBench_TEST_EN",
        "MMBench",
        "MMBench_CN",
        "MMBench_TEST_CN_V11",
        "MMBench_TEST_EN_V11",
        "MMBench_V11",
        "MMBench_CN_V11",
    ] and not MMBenchOfficialServer(dataset_name):
        logger.error(
            f"Can not evaluate {dataset_name} on non-official servers, will skip the evaluation."
        )
        return True
    return False


def _evaluate_with_proxy(dataset, result_file, judge_kwargs):
    """Run ``dataset.evaluate`` with ``$EVAL_PROXY`` (if set) applied as the proxy.

    The previous ``$HTTP_PROXY`` value is restored in a ``finally`` clause, so a
    failing evaluation cannot leak the evaluation proxy into later runs.
    """
    eval_proxy = os.environ.get("EVAL_PROXY", None)
    old_proxy = os.environ.get("HTTP_PROXY", "")
    if eval_proxy is not None:
        proxy_set(eval_proxy)
    try:
        return dataset.evaluate(result_file, **judge_kwargs)
    finally:
        if eval_proxy is not None:
            proxy_set(old_proxy)


def _log_eval_results(eval_results, model_name, dataset_name, logger):
    """Pretty-print evaluation results (a dict or a DataFrame) to the log."""
    if eval_results is None:
        return
    assert isinstance(eval_results, (dict, pd.DataFrame))
    logger.info(
        f"The evaluation of model {model_name} x dataset {dataset_name} has finished! "
    )
    logger.info("Evaluation Results:")
    if isinstance(eval_results, dict):
        logger.info("\n" + json.dumps(eval_results, indent=4))
    else:
        # Transpose wide tables so they render with fewer columns.
        if len(eval_results) < len(eval_results.columns):
            eval_results = eval_results.T
        logger.info("\n" + tabulate(eval_results))


def _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name):
    """Symlink this pair's prediction files (and any status.json) into ``pred_root_meta``.

    Existing files or links at the destination are replaced.
    """
    cwd = os.getcwd()
    prefix = f"{model_name}_{dataset_name}"
    for f in os.listdir(pred_root):
        if prefix not in f and "status.json" not in f:
            continue
        file_addr = osp.join(cwd, pred_root, f)
        link_addr = osp.join(cwd, pred_root_meta, f)
        if osp.exists(link_addr) or osp.islink(link_addr):
            os.remove(link_addr)
        os.symlink(file_addr, link_addr)


def main():
    """Run inference and, unless ``--mode infer``, evaluation for every model x dataset pair.

    Pairs come from ``--model``/``--data`` or from the ``--config`` file.
    Predictions are written to ``<work_dir>/<model>/T<date>_G<commit>/``. After a
    successful evaluation, matching files are symlinked into ``<work_dir>/<model>/``.

    With world size > 1, a NCCL process group is initialised and all ranks run
    inference. Only rank 0 copies reusable results, evaluates, and creates links.
    A failure for one pair is logged, and the loop moves on to the next pair.

    Environment variables:
    - ``MMEVAL_ROOT`` overrides ``--work-dir``.
    - ``LOCAL_RANK`` selects the CUDA device.
    - ``DIST_TIMEOUT`` sets the process-group timeout in seconds (default 3600).
    - ``EVAL_PROXY`` is applied as the proxy during evaluation only.

    Raises:
        ValueError: on an invalid combination of --config / --data / --model.
    """
    logger = get_logger("RUN")
    rank, world_size = get_rank_and_world_size()
    args = parse_args()
    use_config, cfg = False, None
    if args.config is not None:
        if args.data is not None or args.model is not None:
            raise ValueError("--data and --model should not be set when using --config")
        use_config, cfg = True, load(args.config)
        args.model = list(cfg["model"].keys())
        args.data = list(cfg["data"].keys())
    else:
        # Explicit checks: `assert len(None)` raised an opaque TypeError, and a missing
        # --model used to crash later inside the model loop.
        if not args.data:
            raise ValueError("--data should be a list of data files")
        if not args.model:
            raise ValueError("--model should be a list of model names")

    if rank == 0:
        if not args.reuse:
            logger.warning(
                "--reuse is not set, will not reuse previous (before one day) temporary files"
            )
        else:
            logger.warning(
                "--reuse is set, will reuse the latest prediction & temporary pickle files"
            )

    if "MMEVAL_ROOT" in os.environ:
        args.work_dir = os.environ["MMEVAL_ROOT"]

    if not use_config:
        # Push CLI API options into the predefined model partials (mutated in place).
        for v in supported_VLM.values():
            if not hasattr(v, "keywords"):
                continue
            if "retry" in v.keywords and args.retry is not None:
                v.keywords["retry"] = args.retry
            # --verbose is a store_true flag (always a bool), so it always applies.
            if "verbose" in v.keywords:
                v.keywords["verbose"] = args.verbose

    if world_size > 1:
        local_rank = os.environ.get("LOCAL_RANK", 0)
        torch.cuda.set_device(int(local_rank))
        dist.init_process_group(
            backend="nccl",
            timeout=datetime.timedelta(
                seconds=int(os.environ.get("DIST_TIMEOUT", 3600))
            ),
        )

    for model_name in args.model:
        model = None
        date, commit_id = timestr("day"), githash(digits=8)
        eval_id = f"T{date}_G{commit_id}"

        pred_root = osp.join(args.work_dir, model_name, eval_id)
        pred_root_meta = osp.join(args.work_dir, model_name)
        os.makedirs(pred_root_meta, exist_ok=True)

        # Listed before pred_root is created, so a brand-new run dir is not included.
        prev_pred_roots = ls(pred_root_meta, mode="dir")
        if args.reuse:
            prev_pred_roots.sort()
        os.makedirs(pred_root, exist_ok=True)

        if use_config:
            model = build_model_from_config(cfg["model"], model_name)

        for dataset_name in args.data:
            if world_size > 1:
                dist.barrier()

            try:
                result_file_base = f"{model_name}_{dataset_name}.xlsx"

                # If distributed, first build the dataset on the main process for doing
                # preparation works, then let every rank build its own copy.
                if world_size > 1:
                    if rank == 0:
                        _build_dataset_for_run(use_config, cfg, dataset_name, model_name)
                    dist.barrier()
                dataset = _build_dataset_for_run(use_config, cfg, dataset_name, model_name)
                if dataset is None:
                    logger.error(f"Dataset {dataset_name} is not valid, will be skipped. ")
                    continue

                # Multi-turn datasets store their predictions as TSV.
                if dataset.TYPE == "MT":
                    result_file_base = result_file_base.replace(".xlsx", ".tsv")
                result_file = osp.join(pred_root, result_file_base)

                # Reuse the previous prediction file if exists
                if rank == 0 and args.reuse and prev_pred_roots:
                    _reuse_previous_results(
                        prev_pred_roots,
                        pred_root,
                        result_file_base,
                        dataset_name,
                        commit_id,
                        args.reuse_aux,
                        logger,
                    )

                if world_size > 1:
                    dist.barrier()

                if model is None:
                    model = model_name  # which is only a name

                # Perform the Inference; the returned model is reused for later datasets.
                if dataset.MODALITY == "VIDEO":
                    model = infer_data_job_video(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        result_file_name=result_file_base,
                        verbose=args.verbose,
                        api_nproc=args.api_nproc,
                    )
                elif dataset.TYPE == "MT":
                    model = infer_data_job_mt(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.api_nproc,
                        ignore_failed=args.ignore,
                    )
                else:
                    model = infer_data_job(
                        model,
                        work_dir=pred_root,
                        model_name=model_name,
                        dataset=dataset,
                        verbose=args.verbose,
                        api_nproc=args.api_nproc,
                        ignore_failed=args.ignore,
                    )

                # Set the judge kwargs first before evaluation or dumping
                judge_kwargs = _build_judge_kwargs(args, dataset, dataset_name)
                if rank == 0:
                    logger.info(judge_kwargs)

                if world_size > 1:
                    dist.barrier()

                # Only Rank 0 handles the evaluation part
                if rank != 0:
                    continue

                # Prepare Submission Files for MMMU_TEST AND MMT-Bench_ALL
                if dataset_name == "MMMU_TEST":
                    result_json = MMMU_result_transfer(result_file)
                    logger.info(
                        f"Transfer MMMU_TEST result to json for official evaluation, "
                        f"json file saved in {result_json}"
                    )
                    continue
                if "MMT-Bench_ALL" in dataset_name:
                    submission_file = MMTBench_result_transfer(result_file, **judge_kwargs)
                    logger.info(
                        f"Extract options from prediction of MMT-Bench FULL split for official evaluation "
                        f"(https://eval.ai/web/challenges/challenge-page/2328/overview), "
                        f"submission file saved in {submission_file}"
                    )
                    continue

                # Skip the evaluation part if only infer
                if args.mode == "infer":
                    continue
                if _skip_evaluation(dataset_name, result_file, logger):
                    continue

                eval_results = _evaluate_with_proxy(dataset, result_file, judge_kwargs)
                _log_eval_results(eval_results, model_name, dataset_name, logger)
                _link_prediction_files(pred_root, pred_root_meta, model_name, dataset_name)

            except Exception as e:
                # NOTE(review): if only some ranks raise here, the others may block at the
                # next dist.barrier() until the process-group timeout.
                logger.exception(
                    f"Model {model_name} x Dataset {dataset_name} combination failed: {e}, "
                    "skipping this combination."
                )
                continue

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    # load_env() runs before main() so that whatever environment it sets up is
    # visible to main(), which reads MMEVAL_ROOT, LOCAL_RANK, DIST_TIMEOUT and
    # EVAL_PROXY from os.environ. (Presumably it loads a local .env file; confirm
    # in vlmeval.smp.)
    load_env()
    main()
